package com.feidee.fd.sml.algorithm.component.ml.ps

import com.feidee.fd.sml.algorithm.component.ml.MLParam
import com.feidee.fd.sml.algorithm.util.ToolClass
import com.tencent.angel.sona.core.DriverContext
import com.tencent.angel.sona.ml.{PipelineModel, PipelineStage}
import org.apache.spark.linalg.Vectors
import org.apache.spark.ml.classification.ConvertPSModel2Spark
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{DataFrame, SPKSQLUtils, functions}


abstract class AbstractPSMLComponent[A <: MLParam] (implicit m: Manifest[A]) extends AbstractPSComponent[A] with Serializable {

  /**
    * Builds the algorithm pipeline stage from the run parameters.
    *
    * @param param run parameters
    * @return      the configured pipeline stage
    */
  def setUp(param: A): PipelineStage = ???

  /**
    * Trains a model on the Parameter Server.
    *
    * @param param algorithm run parameters
    * @param data  input (training) data
    * @return      the trained PS pipeline model
    */
  def train(param: A, data: DataFrame): PipelineModel = ???

  /**
    * Computes evaluation metrics for a trained model.
    *
    * @param param parameters the model was trained with
    * @param model the trained Spark-side model (implementations may read model
    *              internals to help compute metrics)
    * @param data  DataFrame to evaluate on
    * @return      metric name -> metric value
    */
  def calculateMetrics(param: A, model: org.apache.spark.ml.PipelineModel, data: DataFrame): Map[String, Double] = ???

  /**
    * Saves the PS model to `param.modelPath + "_PS"`. When a flow time is
    * available, an extra copy suffixed with that time is written as a backup
    * and for online use.
    *
    * @param model model to save
    * @param param run parameters (provides `modelPath` and `flow_time`)
    */
  def outputPSModel(model: PipelineModel, param: A): Unit = {
    val psPath = param.modelPath + "_PS"
    model.write.overwrite().save(psPath)
    if (tool.isNotNull(param.flow_time)) {
      // Time-suffixed backup copy.
      model.write.overwrite().save(s"${psPath}_${param.flow_time}")
    } else {
      logWarning("未发现运行时间参数，不做模型备份处理")
    }
  }

  /**
    * Saves the Spark model to `param.modelPath`, plus a time-suffixed backup
    * copy when a flow time is available.
    *
    * @param model model to save
    * @param param run parameters (provides `modelPath` and `flow_time`)
    */
  def outputModel(model: org.apache.spark.ml.PipelineModel, param: A): Unit = {
    model.write.overwrite().save(param.modelPath)
    if (tool.isNotNull(param.flow_time)) {
      // Time-suffixed backup copy.
      model.write.overwrite().save(s"${param.modelPath}_${param.flow_time}")
    } else {
      logWarning("未发现运行时间参数，不做模型备份处理")
    }
  }

  /**
    * Converts the `featureCol` column from a Spark ML vector
    * (`org.apache.spark.ml.linalg.Vector`) into the Angel/Sona vector type
    * (`org.apache.spark.linalg.Vector`), in sparse or dense form. The column
    * keeps its original name.
    *
    * @param dataset    input data
    * @param isSparse   true to emit a sparse vector, false for a dense one
    * @param featureCol name of the feature column to convert
    * @return           DataFrame with the converted feature column
    */
  def convertVector(dataset: DataFrame, isSparse: Boolean, featureCol: String): DataFrame = {
    // Register the Angel vector UDT so the UDF result type is understood by Spark SQL.
    SPKSQLUtils.registerUDT()
    val convert = (sp: Boolean) => udf { vec: org.apache.spark.ml.linalg.Vector =>
      if (sp) {
        val spVec = vec.toSparse
        Vectors.sparse(spVec.size, spVec.indices, spVec.values)
      } else {
        Vectors.dense(vec.toArray)
      }
    }
    // Write into a temporary, non-clashing column first, then swap it in, so the
    // existing column of the same name is replaced safely.
    val existingFields = dataset.schema.fieldNames
    val tmpFeatName = new ToolClass().renameDuplicatedColName(featureCol, existingFields)
    dataset
      .withColumn(tmpFeatName, convert(isSparse)(dataset(featureCol)))
      .drop(featureCol)
      .withColumnRenamed(tmpFeatName, featureCol)
  }

  /**
    * Writes each metric value to an HDFS file named after the metric under `dir`.
    *
    * @param metrics metrics to save
    * @param dir     directory the metric files are written into
    * @return        true only if every metric was saved successfully
    */
  @deprecated("Persist metrics via appendMetrics/outputTable instead", "")
  def outputMetrics(metrics: Map[String, Double], dir: String): Boolean = {
    // Fold instead of a mutable flag; `&` (not `&&`) keeps the original
    // behavior of attempting every write even after a failure.
    metrics.foldLeft(true) { case (allSaved, (name, value)) =>
      allSaved & tool.writeToHDFS(s"$dir/$name", value.toString)
    }
  }

  /**
    * Appends the computed metrics as constant columns to the result DataFrame
    * (the table destination is the configured `hive_table`).
    *
    * @param metrics metric name -> value
    * @param data    result DataFrame
    * @return        result DataFrame with one literal column per metric
    */
  def appendMetrics(metrics: Map[String, Double], data: DataFrame): DataFrame = {
    metrics.foldLeft(data) { case (cur, (name, value)) =>
      cur.withColumn(name, functions.lit(value))
    }
  }

  /**
    * End-to-end run: parse and verify parameters, load data, train on the
    * Parameter Server, persist the PS and Spark models, then optionally save
    * the predictions (and metrics) to the configured output path / Hive table.
    *
    * @param paramStr serialized run parameters
    */
  override def apply(paramStr: String): Unit = {
    logInfo("parsing parameter")
    val param = parseParam(paramStr)
    logInfo("verifying parameter")
    param.verify()

    logInfo(s"loading input data FROM ${param.input_pt}")
    val inputData = loadData(param)
    // Angel PS training expects its own (sparse) vector type in the feature column.
    val data = convertVector(inputData, true, param.featuresCol)
    logInfo(s"app_list type:${data.schema(param.featuresCol).dataType.catalogString}")

    logInfo("training ml model")
    val driverContext = DriverContext.get(spark.sparkContext.getConf)
    driverContext.startAngelAndPSAgent()
    // Guarantee the PS agent is stopped however training/saving/conversion ends,
    // and rethrow the ORIGINAL exception: the previous code wrapped it in a new
    // Exception, losing the exception type and stack trace, and leaked a running
    // agent when outputPSModel or convertPipeline failed.
    val sparkModel = try {
      val model = train(param, data)
      logInfo(s"saving ps model TO ${param.modelPath + "_PS"}")
      outputPSModel(model, param)
      // Convert the PS-side pipeline into a pure Spark pipeline for serving.
      ConvertPSModel2Spark.convertPipeline(spark, param.modelPath + "_PS", model)
    } finally {
      driverContext.stopAngelAndPSAgent()
    }

    var result = sparkModel.transform(inputData)
    logInfo(s"fields:${result.schema.fieldNames.mkString(",")}")
    // head() fetches a single row; collect() would pull the whole result to the driver.
    logInfo(s"result:${result.head().toString}")
    outputModel(sparkModel, param)

    // Save predictions to the output path when one is configured.
    if (tool.isNotNull(param.output_pt)) {
      logInfo(s"saving ml predicted result TO ${param.output_pt}")
      outputData(result, param)
    }

    // Also save to the Hive table when configured, appending metrics if requested.
    if (tool.isNotNull(param.hive_table)) {
      if (param.metrics.nonEmpty) {
        logInfo("calculating metrics")
        val metrics = calculateMetrics(param, sparkModel, result)
        logInfo(s"appending metrics ${metrics.keySet.mkString("[", ", ", "]")}")
        result = appendMetrics(metrics, result)
      }
      logInfo(s"saving ml predicted result to hive table ${param.hive_table}")
      outputTable(result, param)
    }
  }

}
